%Model code for 'Expectations, Infections, and Economic Activity' by
%M. Eichenbaum, M. Godinho de Matos, F. Lima, S. Rebelo and M. Trabandt

%Fall 2022. mathias.trabandt@gmail.com

%Start from a clean session: close all figures, clear variables, globals and
%functions from memory, and clear the command window.
close all
clear all
clc

%Reset the random number generator to its default settings (reproducibility).
rng default

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Run switches for the estimation part (1 = on, 0 = off)
mpar.do_estimation=0;               %estimate parameters; otherwise simulate with given values

mpar.start_mode_finder=1;           %run the fmincon mode finder; otherwise load last_estimation_step
mpar.load_last_estimation_step=0;   %start the mode finder from the last saved solution
mpar.est_maxiter=1000;              %max. function evaluations of the mode finder
mpar.do_checkplots=0;               %compute log-posterior slices around the mode
mpar.grdpts=5;                      %number of check-plot grid points (must be odd)
mpar.use_parallel=1;                %let fmincon use parallel computing

mpar.do_mcmc=0;                     %run MCMC after the mode finder
mpar.do_load_mcmc_results=0;        %load stored MCMC results instead of running MCMC


%Subset of the asset grid used for the epidemic simulation and estimation
%(a subset of mpar.nb points speeds up computations): the grid points from
%mpar.nbsim_below points below the young agents' asset point up to
%mpar.nbsim_above points above the old agents' asset point.
%Young people typically hold fewer assets than old people, and the asset grid
%below is log-spaced, so more points may be wanted below than above.
%NOTE(review): the young, old and median asset indices are all currently set
%from the same 75k target further below, so they coincide (index 137 on the
%baseline 150-point grid); the young=133 / old=140 figures quoted previously no
%longer apply -- confirm.
mpar.nbsim_below=1;                 %grid points below the young asset grid point
mpar.nbsim_above=1;                 %grid points above the old asset grid point
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

mpar.infect_surprise=1;             %=1: the second wave comes as a surprise
mpar.infect_surprise_week=17;       %last week of the first wave; the next wave starts (unexpectedly) in week mpar.infect_surprise_week+1

mpar.timestart=5;                   %first simulated week (weeks since March 1)

%Options for fminbnd (passed on to the value function routines)
fminbnd_options=optimset('Display','off','MaxFunEvals',100,'TolX',1e-12);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%Parameters
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%% Numerical settings
%Asset grid (log-spaced, see below). mpar.minb is chosen so that a grid point
%lies close to 75k, the median assets in the data.
mpar.nb   = 150;                    %number of asset grid points
mpar.minb = 50;                     %lowest asset grid point
mpar.maxb = 150000;                 %highest asset grid point

mpar.crit_run1 = 1e-7;              %convergence tolerance (sup-norm) of the value function iteration

%Solve and simulate the model (=1) or load it from disk. Solving can take a
%substantial amount of time (value function iteration).
mpar.do_sol_sim = 1;

mpar.loadVini = 1;                  %start the I=0 value function iteration from the stored solution
mpar.show_valfun_iter = 0;          %=1: print the value-function distance at every iteration

%Quick calibration mode (=1): smaller asset grid and lower precision for the
%I=0 value functions
mpar.quick_calib = 0;
mpar.quick_calib_precision = 1e-3;  %VFI tolerance used in quick calibration mode
mpar.run_calib_matching = 0;        %=1: solve numerically (fsolve) for parameters matching the data targets

%SIR model calibration settings
mpar.opts_fsolve = optimoptions('fsolve','Display','off','TolFun',1e-7,'MaxFunctionEvaluations',100000,'MaxIterations',4000); %fsolve options
mpar.HH = 200;                      %number of periods in the SIR model simulation

mpar.scale1 = 1000;                 %scaling of pi1 for the numerical solver

%% Economic parameters (weekly model frequency)
par.betta  = 0.97^(1/52);           %discount factor (0.97 per year)
par.deltao = 1/(13*52);             %weekly prob. of dying naturally, old
par.deltay = 1/(51*52);             %weekly prob. of dying naturally, young
par.nu     = 1/(28*52);             %weekly prob. of becoming old (aging)

par.R0 = 2.5;                       %initial basic reproduction number


%Starting values of the parameters that are estimated
par.pidy_ini = 0.089052852121211;   %initial value, constant-gain learning, young
par.pido_ini = 0.428188485888119;   %initial value, constant-gain learning, old
par.gyoung   = 0.069289753446703;   %constant-gain parameter, young
par.gold     = 0.091977911672371;   %constant-gain parameter, old

par.muvec_scale = 0.525754629721827; %scaling factor of the containment measure
par.kappa       = 0.069099320356513; %average share of consumption-based infections



par.sy_shr = 0.7;                   %population share of the young

par.alf = 2;                        %Epstein-Zin risk aversion (Albuquerque, Eichenbaum, Luo, Rebelo, benchmark model)
par.rho = 1/1.5;                    %inverse IES (same source)

par.r = 1.01^(1/52)-1;              %weekly real interest rate (1% per year)

par.totinc = 19000;                 %total after-tax income per year

%Utility parameters, set to match the data targets listed below
par.z      = 2.663;                 %constant in the utility function (outside the EZ aggregator)
par.omega0 = 159.512;               %intercept of bequest utility
par.omega1 = 4.883;                 %slope of bequest utility (in terms of assets)

%Data targets:
%  value of life (VoL): 890,000
%  ratio of consumption young/old: 1.18
%  aggregate savings rate: 6.7%

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%Mortality rates, Infections and containment     %%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%Infection-fatality-rate data from Sorensen et al (Lancet, 2022); provides
%ifr_data_weekly.trend used below.
get_ifr_data;

par.days=14;  %average number of days until an infected person recovers or dies

%Weekly probabilities of leaving infection. An infection ends within the week
%with probability 7/par.days; the death part is a base fatality rate (0.1%
%young, 3.5% old) scaled by the Sorensen et al time trend, the rest recovers.
par.pidy_vec_true = 7/par.days*0.001*(ifr_data_weekly.trend);  %dying, young
par.piry_vec_true = 7*1/par.days - par.pidy_vec_true;          %recovering, young
par.pido_vec_true = 7/par.days*0.035*(ifr_data_weekly.trend);  %dying, old
par.piro_vec_true = 7*1/par.days - par.pido_vec_true;          %recovering, old

%Weekly new deaths (provides newdeaths_weekly); infections are backed out from them
get_weeklynewdeaths;

%Population-weighted weekly death probability of the infected
agg_pid = par.sy_shr*par.pidy_vec_true + (1-par.sy_shr)*par.pido_vec_true;

%Aggregate SIR death equation: deaths in week t+1 = agg_pid(t) * infections in
%week t. Dividing by 10 million expresses infections as a share of the initial
%population.
Ivec_RAW = newdeaths_weekly.NewDeaths(2:end)./agg_pid(1:end-1)/10000000;
%Append I=0 as the last observation: the backward induction of the perfect
%foresight simulation starts from the I=0 value functions.
Ivec = [Ivec_RAW',0];
Ivec = Ivec(mpar.timestart:end);   %simulation starts in week mpar.timestart

%Containment measure over the simulation window (same length as Ivec)
get_contain;
muvec = zeros(size(Ivec));
muvec(:) = contain_weekly.contain(mpar.timestart:end);

%Align the death/recovery probability paths with the simulation start week
for fld = {'pidy_vec_true','piry_vec_true','pido_vec_true','piro_vec_true'}
    par.(fld{1}) = par.(fld{1})(mpar.timestart:end);
end

%Starting values of the transmission parameters; they are overwritten later
%on from kappa and R0.
par.pi1=1.699874013744988e-04; %consumption-based infections
par.pi2=1.192123295885852;     %general infections

%Initial seed of total infections for the basic SIR model (taken from the data)
%NOTE(review): Ivec has already been trimmed to start at week mpar.timestart
%above, so Ivec(mpar.timestart+1) is week 2*mpar.timestart of the raw series
%rather than the start week -- confirm this offset is intended.
par.i_ini=Ivec(mpar.timestart+1);
if par.i_ini==0
    disp('Initial infections data are zero. Setting it to 0.1%');
    par.i_ini=0.001;
end

%% Setting model and structures for value function iteration and simulation
%% Grid for assets
%Log-spaced asset grid between mpar.minb and mpar.maxb
grid.b = exp(linspace(log(mpar.minb),log(mpar.maxb),mpar.nb));

%Asset level (median assets in the data, 75k) at which median, young and old
%agents are placed on the grid
assets_target=75000;

%Quick run for calibration: keep only the grid point closest to the asset
%target and its two neighbours (indices 136:138 on the baseline 150-point grid)
%and lower the precision of the value function iteration.
if mpar.quick_calib==1
    [~,idx_target]=min(abs(grid.b-assets_target));
    idx_target=min(max(idx_target,2),numel(grid.b)-1); %keep both neighbours on the grid
    grid.b=grid.b(idx_target-1:idx_target+1);
    mpar.nb=numel(grid.b);
    mpar.crit_run1=mpar.quick_calib_precision;    %precision of value function iteration
end

%Grid indices (and levels) of median, young and old assets: the grid points
%closest to the asset target
[~,par.grid_b_median_assets_idx]=min(abs(grid.b-assets_target));
[~,par.grid_b_young_assets_idx]=min(abs(grid.b-assets_target));
[~,par.grid_b_old_assets_idx]=min(abs(grid.b-assets_target));

par.median_assets=grid.b(par.grid_b_median_assets_idx); %median assets on b grid
par.young_assets=grid.b(par.grid_b_young_assets_idx);   %young assets on b grid
par.old_assets=grid.b(par.grid_b_old_assets_idx);       %old assets on b grid

%Annual interest income at the young/old asset levels and its population-weighted average
par.intinc_old=((1+par.r)^52-1)*par.old_assets;
par.intinc_young=((1+par.r)^52-1)*par.young_assets;
par.intinc_avg=par.sy_shr*par.intinc_young+(1-par.sy_shr)*par.intinc_old;

%Grid points used for the epidemic simulation and model estimation
mpar.grid_sim_idx=[par.grid_b_young_assets_idx-mpar.nbsim_below:1:par.grid_b_old_assets_idx+mpar.nbsim_above];

%Weekly after-tax labor income, set so that annual labor income plus average
%interest income equals par.totinc (19k)
par.w=(par.totinc-par.intinc_avg)/52;

%% Available resources, to be split between consumption and saving
%Y is not net income (wage + interest income) but available resources: wage
%income plus asset interest plus the repaid asset principal, w+(1+r)*b.
[meshes.b]= ndgrid(grid.b);
Y=par.w+(1+par.r)*meshes.b;

%Annual income (wage income plus net interest income) at the lowest and highest asset grid point
par.lowinc_ann=52*(par.w+par.r*meshes.b(1));
par.highinc_ann=52*(par.w+par.r*meshes.b(end));

%% Value functions by type are collected in the struct U (filled below)
U=[];

%Type flags used throughout the code:
%  flag=1 - ro (recovered, old)      flag=2 - ry (recovered, young)
%  flag=3 - io (infected, old)       flag=4 - iy (infected, young)
%  flag=5 - so (susceptible, old)    flag=6 - sy (susceptible, young)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%SOLVE VALUE FUNCTION FOR I=0 for all t     %%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%The I=0 problem uses the end-of-sample weekly probabilities of dying (pidy,
%pido) and recovering (piry, piro), young and old.
for fld = {'pidy','piry','pido','piro'}
    par.(fld{1}) = par.([fld{1} '_vec_true'])(end);
end


if mpar.run_calib_matching==1
    %Calibrate z, omega0 and omega1 by solving calibrate_params(x)=0 with fsolve.
    %The unknowns are x=[z; omega0/100; omega1] (omega0 is rescaled for the solver).
    disp(' ');
    disp('Calibrating z, omega0, omega1: Solving Value Functions (this may take a while)');
    disp('------------------------------------------------------------------------------');
    
    %Starting point for fsolve. (A guess built from the current parameters
    %would be [par.z; par.omega0/100; par.omega1].)
    calib_guess= [  2.666778190334839
        1.583329999508420
        4.894318211765481];
    
    %Share the stored initial value functions as a global (presumably read by
    %calibrate_params -- confirm). The global declaration must come BEFORE the
    %load: declaring a variable global when a local variable of the same name
    %already exists replaces the local value with the global one (empty here,
    %since 'clear all' removed all globals).
    global V_ini;
    load V_ini;
    
    [sol,fff,exitflag] = fsolve(@calibrate_params,calib_guess,optimoptions('fsolve','Display','iter','MaxFunctionEvaluations',100,'UseParallel',false),par,mpar,Y,grid,U,fminbnd_options);
    
    format short;sol
    format long
    
    %Map the solution back to the parameters (undo the omega0 scaling)
    par.z=sol(1);
    par.omega0=sol(2)*100;
    par.omega1=sol(3);
    
    %Calibration moments at the solution
    [fff,avg_VoL,avg_savings_rate,cry_cro_ratio]=calibrate_params(sol,par,mpar,Y,grid,U,fminbnd_options);
    
    disp(['avg VoL, mill, initial assets    = ',num2str(avg_VoL)]);
    disp(['cry_cro_ratio, initial assets    = ',num2str(cry_cro_ratio)]);
    disp(['avg savings_rate, ini assets     = ',num2str(avg_savings_rate)]);
    disp(' ');
    
end


disp(' ');
disp('Solving Value Functions (this may take a while)');
disp('-----------------------------------------------');
if mpar.do_sol_sim==1
    tstart=tic;
    %Initial guess for the I=0 value functions, one column per type flag.
    %With mpar.loadVini==1 the stored solution in V_ini.mat is used instead,
    %provided it exists and matches the current asset grid (it need not, e.g.
    %after switching mpar.quick_calib). The file is read once, before the loop,
    %and only the variable V_ini is taken from it.
    V_ini=2000+zeros(mpar.nb,6);
    if mpar.loadVini==1
        if exist('V_ini.mat','file')==2
            V_ini_default=V_ini;
            load('V_ini.mat','V_ini');
            if ~isequal(size(V_ini),[mpar.nb,6])
                warning('V_ini.mat does not match the current asset grid; using the default initial guess.');
                V_ini=V_ini_default;
            end
        else
            warning('V_ini.mat not found; using the default initial guess.');
        end
    end
    for flag=1:1:6
        %Susceptible types start from the recovered solutions, which are
        %computed earlier in this loop (flags 1 and 2)
        if flag==5
            V=U.ro;
        elseif flag==6
            V=U.ry;
        else
            V=V_ini(:,flag);
        end
        I=0;mu=0;solflag=1;
        dist=9999;count=1;tic;
        %Fixed-point residual V - T(V), with T one update by VFI_update_spline
        dist_V = @(V)(V-VFI_update_spline(V,Y,par,mpar,grid,U,flag,solflag,I,mu,[],[],[],[],[],[],fminbnd_options));
        while dist(count)>mpar.crit_run1
            count = count+1;
            DV    = dist_V(V);
            dist(count) = max(abs(DV(:))); % sup-norm distance between old guess and update
            V     = V-DV;                  % i.e. V = T(V)
            if mpar.show_valfun_iter==1
                disp(dist(count))
            end
        end
        Vsol=V;
        %Alternative solver (Broyden), kept for reference:
        %mpar.crit = 1e-3;       %precision up to which to solve the value function with broyden
        %mpar.broydencritx=1e-7; %min dx in broyden
        %mpar.broydeniter=200;   %max iter broyden
        %[Vsol,~,count,dist] = broyden(dist_V,V,mpar.crit,mpar.broydencritx,mpar.broydeniter);
        [~,bprimesol]=VFI_update_spline(Vsol,Y,par,mpar,grid,U,flag,solflag,I,mu,[],[],[],[],[],[],fminbnd_options);
        csol=par.w+(1+par.r)*grid.b(:)-bprimesol;  %consumption from the budget constraint
        disp(['Flag ',num2str(flag),': Number of evaluations: ',num2str(count-1),'. Time: ',num2str(toc,2), ' sec.']);
        
        if     flag==1, U.ro=Vsol; bprime.ro=bprimesol;c.ro=csol;distF.ro=dist;
        elseif flag==2, U.ry=Vsol; bprime.ry=bprimesol;c.ry=csol;distF.ry=dist;
        elseif flag==3, U.io=Vsol; bprime.io=bprimesol;c.io=csol;distF.io=dist;
        elseif flag==4, U.iy=Vsol; bprime.iy=bprimesol;c.iy=csol;distF.iy=dist;
        elseif flag==5, U.so=Vsol; bprime.so=bprimesol;c.so=csol;distF.so=dist;
        elseif flag==6, U.sy=Vsol; bprime.sy=bprimesol;c.sy=csol;distF.sy=dist;
        end
    end
    disp(' ');
    disp(['Total time in minutes: ',num2str((toc(tstart))/60,3)]);
    
    %Store the converged I=0 solution as initial guess for later runs
    V_ini=[U.ro U.ry U.io U.iy U.so U.sy];
    save V_ini V_ini;
    
    %Sanity checks: being infected must not be better than being susceptible
    if any((U.io-U.so)>0)
        disp('old: utility for infected larger than for susceptibles');
        return;
    end
    
    if any((U.iy-U.sy)>0)
        disp('young: utility for infected larger than for susceptibles');
        return;
    end
    
    
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %Command Window Output
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    %Value of life, consumption and savings at the initial asset levels of the
    %no-epidemic (I=0) solution, printed for reporting in the paper.
    disp(' ');
    %Initial (no-epidemic) consumption of old and young. Recovered policies are
    %used since they coincide with the susceptible ones when I=0.
    par.cro_noepi_ini=c.ro(par.grid_b_old_assets_idx); par.cry_noepi_ini=c.ry(par.grid_b_young_assets_idx);
    
    %Value of life in millions: value function divided by the marginal utility
    %of consumption, (U-z)^rho*(1-betta)*c^(-rho)
    muc_ro=(U.ro(par.grid_b_old_assets_idx)-par.z)^par.rho*(1-par.betta)*c.ro(par.grid_b_old_assets_idx)^(-par.rho);
    VoL_old_assets_millions=U.ro(par.grid_b_old_assets_idx)/muc_ro/1000000;
    muc_ry=(U.ry(par.grid_b_young_assets_idx)-par.z)^par.rho*(1-par.betta)*c.ry(par.grid_b_young_assets_idx)^(-par.rho);
    VoL_young_assets_millions=U.ry(par.grid_b_young_assets_idx)/muc_ry/1000000;
    
    %Consumption ratio young/old and savings rates (income - c)/income with
    %income = w + r*b, plus their population-weighted averages
    cry_cro_ratio=c.ry(par.grid_b_young_assets_idx)/c.ro(par.grid_b_old_assets_idx);
    savings_rate_ro=(par.w+par.r*grid.b(par.grid_b_old_assets_idx)-c.ro(par.grid_b_old_assets_idx))/(par.w+par.r*grid.b(par.grid_b_old_assets_idx));
    savings_rate_ry=(par.w+par.r*grid.b(par.grid_b_young_assets_idx)-c.ry(par.grid_b_young_assets_idx))/(par.w+par.r*grid.b(par.grid_b_young_assets_idx));
    avg_savings_rate=par.sy_shr*savings_rate_ry+(1-par.sy_shr)*savings_rate_ro;
    avg_VoL=par.sy_shr*VoL_young_assets_millions+(1-par.sy_shr)*VoL_old_assets_millions;
    
    
    disp(['avg VoL, mill, initial assets    = ',num2str(avg_VoL)]);
    disp(['cry_cro_ratio, initial assets    = ',num2str(cry_cro_ratio)]);
    disp(['avg savings_rate, ini assets     = ',num2str(avg_savings_rate)]);
    disp(' ');
    
    disp(['Median assets on grid: ' num2str(grid.b(par.grid_b_median_assets_idx))]);
    disp(['Young assets on grid: ' num2str(grid.b(par.grid_b_young_assets_idx))]);
    disp(['Old assets on grid: ' num2str(grid.b(par.grid_b_old_assets_idx))]);
    disp(['VoL old, ini assets, mill = ',num2str(VoL_old_assets_millions)]);
    disp(['VoL young, ini assets, mill =  ',num2str(VoL_young_assets_millions)]);
    disp(['avg VoL, ini assets, mill = ',num2str(avg_VoL)]);
    disp(['cro, ini assets = ',num2str(c.ro(par.grid_b_old_assets_idx))]);
    disp(['cry, ini assets =  ',num2str(c.ry(par.grid_b_young_assets_idx))]);
    disp(['savings_rate_ro, ini assets = ',num2str(savings_rate_ro)]);
    disp(['savings_rate_ry, ini assets =  ',num2str(savings_rate_ry)]);
    disp(['avg_savings_rate, ini assets = ',num2str(avg_savings_rate)]);
    disp(['cry_cro_ratio, ini assets =  ',num2str(cry_cro_ratio)]);
    disp(' ');
    
    
    %Set the death/recovery probabilities to their beginning-of-sample
    %(initial) values; they enter the calibration of pi1 and pi2 below.
    par.pidy=par.pidy_vec_true(1);%Weekly probability of dying, young
    par.piry=par.piry_vec_true(1);%Weekly probability of recovering, young
    par.pido=par.pido_vec_true(1);%Weekly probability of dying, old
    par.piro=par.piro_vec_true(1);%Weekly probability of recovering, old
    
    %Back out pi1 (consumption-based transmission) and pi2 (general
    %transmission) from R0 and kappa. The two defining identities are:
    %kappa=(par.pi1*(par.cry_noepi_ini*par.sy_shr+par.cro_noepi_ini*(1-par.sy_shr)))/(par.pi1*(par.cry_noepi_ini*par.sy_shr+par.cro_noepi_ini*(1-par.sy_shr))+par.pi2); %take weighted avg of old and young consumption
    %par.Rnot=(par.pi2+par.pi1*par.cry_noepi_ini*par.sy_shr+par.pi1*par.cro_noepi_ini*(1-par.sy_shr) ) /(par.piry*par.sy_shr+par.piro*(1-par.sy_shr)+par.pidy*par.sy_shr+par.pido*(1-par.sy_shr));
    %Solve the two equations above for par.pi1 and par.pi2:
    %  tmp  = population-weighted no-epidemic consumption
    %  tmp2 = population-weighted weekly probability of leaving infection
    %         (recovery + death)
    %  R0 identity:    pi2 + pi1*tmp = R0*tmp2
    %  kappa identity: pi1*tmp = kappa*(pi1*tmp + pi2)
    %  => pi2 = (1-kappa)*R0*tmp2,  pi1 = kappa*pi2/((1-kappa)*tmp)
    tmp=(par.cry_noepi_ini*par.sy_shr+par.cro_noepi_ini*(1-par.sy_shr));
    tmp2=(par.piry*par.sy_shr+par.piro*(1-par.sy_shr)+par.pidy*par.sy_shr+par.pido*(1-par.sy_shr));
    par.pi2=(1-par.kappa)*par.R0*tmp2;
    par.pi1=par.kappa*par.pi2/((1-par.kappa)*tmp);
    %disp(['pi1 = ',num2str(par.pi1),', pi2 = ',num2str(par.pi2)]);
    %Re-evaluate both identities with the solved pi1, pi2 as a consistency
    %check: these should reproduce par.kappa and par.R0.
    par.kappa_ini_calib=(par.pi1*(par.cry_noepi_ini*par.sy_shr+par.cro_noepi_ini*(1-par.sy_shr)))/(par.pi1*(par.cry_noepi_ini*par.sy_shr+par.cro_noepi_ini*(1-par.sy_shr))+par.pi2); %take weighted avg of old and young consumption
    par.Rnot_ini_calib=(par.pi2+par.pi1*par.cry_noepi_ini*par.sy_shr+par.pi1*par.cro_noepi_ini*(1-par.sy_shr) ) /(par.piry*par.sy_shr+par.piro*(1-par.sy_shr)+par.pidy*par.sy_shr+par.pido*(1-par.sy_shr));
    
    
    
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %Willingness to pay to avoid the epidemic, part 1: no-epidemic baseline
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    %Simulate the I=0 economy forward from the initial old/young asset levels
    %(susceptible types), recording assets, consumption and the value function
    %along the path.
    %
    %Notation:
    %  grid.b : asset grid, assets carried in from t-1
    %  bprime : asset policy for t as a function of assets from t-1
    %  consumption follows from the budget constraint c = w + (1+r)*b - b'
    
    ZZ=3000; %number of periods simulated
    
    c_NoEpiSim=struct('so',NaN(ZZ,1),'sy',NaN(ZZ,1));
    b_NoEpiSim=struct('so',NaN(ZZ,1),'sy',NaN(ZZ,1));
    U_NoEpiSim=struct('so',NaN(ZZ,1),'sy',NaN(ZZ,1));
    
    b_NoEpiSim.so(1)=grid.b(par.grid_b_old_assets_idx);   %initial assets, old
    b_NoEpiSim.sy(1)=grid.b(par.grid_b_young_assets_idx); %initial assets, young
    
    for jj=2:ZZ+1
        %old susceptibles: next assets from the policy, consumption from the budget
        b_NoEpiSim.so(jj)   = interp1(grid.b,bprime.so,b_NoEpiSim.so(jj-1),'linear','extrap');
        c_NoEpiSim.so(jj-1) = par.w+(1+par.r)*b_NoEpiSim.so(jj-1)-b_NoEpiSim.so(jj);
        U_NoEpiSim.so(jj-1) = interp1(grid.b,U.so,b_NoEpiSim.so(jj-1),'linear','extrap');
        
        %young susceptibles: same steps
        b_NoEpiSim.sy(jj)   = interp1(grid.b,bprime.sy,b_NoEpiSim.sy(jj-1),'linear','extrap');
        c_NoEpiSim.sy(jj-1) = par.w+(1+par.r)*b_NoEpiSim.sy(jj-1)-b_NoEpiSim.sy(jj);
        U_NoEpiSim.sy(jj-1) = interp1(grid.b,U.sy,b_NoEpiSim.sy(jj-1),'linear','extrap');
    end
    
    
    %The remaining willingness-to-pay calculations follow toward the end of
    %the code.
    
    
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%SIMULATION OF EPIDEMIC                     %%%%%%%%%%%%%%%%%%
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %The value functions above are for I=0 (no infections). The epidemic is
    %simulated under perfect foresight: starting from the terminal period and
    %recursing backward in time, agents in each period use the next period's
    %value function. A switch (mpar.infect_surprise) instead makes later waves
    %unexpected (no perfect foresight).
    
    %Population accounting, needed later for per capita consumption: split
    %infections by population shares, iterate the SIR equations for recovered
    %and dead, and obtain susceptibles as the residual.
    Iy=Ivec'*par.sy_shr;
    Io=Ivec'*(1-par.sy_shr);
    
    Ry=zeros(numel(Ivec),1); Ro=zeros(numel(Ivec),1);
    Dy=zeros(numel(Ivec),1); Do=zeros(numel(Ivec),1);
    
    for j=1:numel(Ivec)-1
        %recovered: survivors who stay in their age group, plus new recoveries;
        %young recovered who age (prob. nu) move to the old recovered
        Ry(j+1,1)=Ry(j)*(1-par.deltay-par.nu)+par.piry_vec_true(j)*Iy(j)*(1-par.deltay-par.nu);
        Ro(j+1,1)=Ro(j)*(1-par.deltao)+par.nu*Ry(j)+par.piry_vec_true(j)*Iy(j)*par.nu+par.piro_vec_true(j)*Io(j)*(1-par.deltao);
        
        %cumulative deaths from infection
        Dy(j+1,1)=Dy(j)+par.pidy_vec_true(j)*Iy(j);
        Do(j+1,1)=Do(j)+par.pido_vec_true(j)*Io(j);
    end
    Sy=par.sy_shr-Ry-Dy-Iy;
    So=(1-par.sy_shr)-Ro-Do-Io;
    
    
    %Value functions and asset policies over time for each type: rows are the
    %simulation subgrid (mpar.grid_sim_idx), columns are weeks. The last
    %column holds the I=0 solution, the starting point of the backward
    %recursion.
    for type_name={'ro','ry','io','iy','so','sy'}
        tn=type_name{1};
        Umat.(tn)=zeros(numel(mpar.grid_sim_idx),numel(Ivec));
        Umat.(tn)(:,end)=U.(tn)(mpar.grid_sim_idx);
        bprimemat.(tn)=zeros(numel(mpar.grid_sim_idx),numel(Ivec));
        bprimemat.(tn)(:,end)=bprime.(tn)(mpar.grid_sim_idx);
    end
    
    solflag=0; %from here on: simulate rather than solve the I=0 value functions
    
    %Consumption data young and old in percent (see xls file); monthly, from
    %March 2020 through April 2021
    dat_young=100*[-0.127	-0.322	-0.211	-0.012	0.046	0.046	0.024	-0.084	-0.059	-0.127	-0.296	-0.212	-0.083	-0.072];
    dat_old=100*[  -0.136	-0.419	-0.254	-0.074	-0.027	-0.040	-0.031	-0.119	-0.109	-0.184	-0.363	-0.254	-0.099	-0.085];
    
    %Standard errors of the consumption data young and old (see xls file)
    dat_stderr_young=[0.004577447	0.005042758	0.004922426	0.004931662	0.005567094	0.005077663	0.005096216	0.005207857	0.005260694	0.005419005	0.00568363	0.005701155	0.005936621	0.005972153];
    dat_stderr_old  =[0.004376884	0.005031653	0.004743569	0.004813666	0.005235872	0.004921257	0.004925781	0.005106109	0.005132598	0.005329912	0.005757338	0.005721398	0.005909867	0.005981088];
    par.dat_stderr_young=dat_stderr_young;
    par.dat_stderr_old=dat_stderr_old;
    
    %Diagonal covariance matrix of the data (young first, then old); its
    %inverse and log-determinant are passed on via par. Both are computed
    %elementwise from the variances: log(det(VVmat)) multiplies many small
    %variances and can underflow to -Inf as the number of observations grows.
    VVvec=[dat_stderr_young(1:end)';dat_stderr_old(1:end)'].^2;
    VVmat=diag(VVvec);
    invVVmat=diag(1./VVvec);
    logdetVVmat=sum(log(VVvec));
    par.invVVmat=invVVmat;
    par.logdetVVmat=logdetVVmat;
    
    %Daily dates covering the model-implied consumption data
    dt_cons = datetime('01/Mar/2020') : datetime('15/May/2021');
    
    %Names of the estimated parameters
    par.guess_name=[{'mu_scale  '};{'kappa     '}];
    
    %Initial guess
    guess=[ 0.1
        0.1
        ];
    
    %Lower and upper bounds for the fmincon mode finder; all priors are
    %uniform, and these are also the bounds of the priors
    par.LB=[0;0];
    par.UB=[1;1];
    
    
    if mpar.do_estimation==1
        tic
        disp(' ');
        disp('Starting Estimation');
        disp('-------------------');
        
        %Options for the fmincon (SQP) mode finder. The serial and parallel
        %settings differ only in 'UseParallel', which follows mpar.use_parallel.
        %Earlier alternative settings, kept for reference:
        %opts_fmincon=optimoptions('fmincon','Display','iter','ConstraintTolerance',1e-7,'StepTolerance',1e-7,'TolFun',1e-7,'MaxFunctionEvaluations',2,'UseParallel',false,'FiniteDifferenceStepSize',1e-7,'Algorithm','sqp'); %w/o parallel comp.
        %opts_fmincon=optimoptions('fmincon','Display','iter','ConstraintTolerance',1e-12,'StepTolerance',1e-12,'TolFun',1e-12,'MaxFunctionEvaluations',3000,'UseParallel',true,'FiniteDifferenceStepSize',1e-8,'Algorithm','sqp'); %with parallel comp.
        if ismember(mpar.use_parallel,[0 1])
            opts_fmincon=optimoptions('fmincon','Display','iter','MaxFunctionEvaluations',mpar.est_maxiter, ...
                'UseParallel',mpar.use_parallel==1,'algorithm','sqp', ...
                'ConstraintTolerance',1e-12,'StepTolerance',1e-12,'TolFun',1e-12,'FiniteDifferenceStepSize',1e-4);
        end
        
        %Optionally restart from the last saved mode-finder solution
        if mpar.load_last_estimation_step==1
            disp('Loading last mode finder step');
            load('last_estimation_step');
            guess=sol;
        end
        
        if mpar.start_mode_finder==1
            %Posterior mode: minimise the negative log posterior returned by
            %simulate_model within the prior bounds [par.LB, par.UB]
            [sol,post_value,exitflag,OUTPUT,LAMBDA,GRAD,HESSIAN_FMINCON] = fmincon(@simulate_model,guess,[],[],[],[],par.LB,par.UB,[],opts_fmincon,Ivec,muvec,par,U,Y,mpar,grid,solflag,Umat,bprimemat,Sy,Ry,Iy,So,Ro,Io,newdeaths_weekly,dat_young,dat_old,dt_cons,fminbnd_options,c_NoEpiSim);
            
            disp('Computing Hessian: this may take a while...');
            disp('-------------------------------------------');
            %Hessian of simulate_model (negative log posterior) at the mode via
            %the hessian routine; fmincon's HESSIAN_FMINCON is not used
            hessian_mat=reshape(hessian('simulate_model',sol,eps,Ivec,muvec,par,U,Y,mpar,grid,solflag,Umat,bprimemat,Sy,Ry,Iy,So,Ro,Io,newdeaths_weekly,dat_young,dat_old,dt_cons,fminbnd_options,c_NoEpiSim),numel(sol),numel(sol));
            invhessian=inv(hessian_mat);
            
            %Save for later runs (restart or skipping the mode finder)
            save('last_estimation_step','sol','hessian_mat','invhessian');
        else
            load('last_estimation_step');
        end
        
        %Parameters at the mode
        par.muvec_scale=sol(1);  %scaling factor of the containment measure
        par.kappa=sol(2);        %average share of consumption-based infections
        
        
        disp(' ');
        disp('Simulating with estimated parameters');
        disp('------------------------------------');
        [neglogpost,cmat,comat,cymat,cons_monthly,pidy_vec_beliefs,pido_vec_beliefs,par_pi1,par_pi2,bprimemat,Umat]=simulate_model(sol,Ivec,muvec,par,U,Y,mpar,grid,solflag,Umat,bprimemat,Sy,Ry,Iy,So,Ro,Io,newdeaths_weekly,dat_young,dat_old,dt_cons,fminbnd_options,c_NoEpiSim);
        par.pi1=par_pi1;
        par.pi2=par_pi2;
        
        %Posterior mode and log posterior at the mode
        para_mode=sol;
        logpost_mode=-neglogpost;
        
        %Laplace approximation of the log marginal data density:
        %  log p(Y) ~ k/2*log(2*pi) + log posterior(mode) + 1/2*log(det(Sigma)),
        %where Sigma = invhessian is the inverse Hessian of the NEGATIVE log
        %posterior at the mode (simulate_model returns neglogpost), k = numel(sol).
        %Note the + sign on the log-determinant of the inverse Hessian.
        Laplace = 1/2*numel(sol)*log(2*pi) + (logpost_mode) + 1/2*log(det(invhessian))
        
        disp(' ');
        disp('Estimated Parameter(s)');
        disp('----------------------');
        for iii=1:1:numel(sol)
            disp([par.guess_name(iii),sol(iii)])
        end
        
        %Posterior standard deviations and parameter correlations from the
        %inverse Hessian
        [para_standard_deviations,para_correlation_matrix]=Cov2Corr(invhessian);
        para_standard_deviations=para_standard_deviations'
        
        %Report parameter pairs whose posterior correlation exceeds
        %ExtremeCorrBound in absolute value (NaN disables the report)
        ExtremeCorrBound=0.15;
        if  ~isnan(ExtremeCorrBound)
            tril_para_correlation_matrix=tril(para_correlation_matrix,-1);
            [RowIndex,ColIndex]=find(abs(tril_para_correlation_matrix)>ExtremeCorrBound);
            ExtremeCorrParams=cell(length(RowIndex),3);
            for i=1:length(RowIndex)
                ExtremeCorrParams{i,1}=char(par.guess_name(RowIndex(i)));
                ExtremeCorrParams{i,2}=char(par.guess_name(ColIndex(i)));
                ExtremeCorrParams{i,3}=tril_para_correlation_matrix(RowIndex(i),ColIndex(i));
            end
            %Displayed inside the branch: ExtremeCorrParams only exists when
            %the check has been run
            disp(' ');
            disp(['Correlations of Parameters (at Posterior Mode) > abs(',num2str(ExtremeCorrBound),')']);
            disp(ExtremeCorrParams)
        end
        
        %Prior vs. posterior plots. The posterior is the Laplace (normal)
        %approximation around the mode; all priors are uniform on
        %[par.LB, par.UB].
        figure('name','Prior-Posterior (Laplace)');
        ia=3;ib=3;
        trunc = 1e-12; steps = 2^15-1;
        
        for zz=1:1:numel(sol)
            %Normal approximation, evaluated between its trunc and 1-trunc quantiles
            a = sol(zz);b = para_standard_deviations(zz);infbound = norminv(trunc,a,b);
            supbound = norminv(1-trunc,a,b); stepsize = (supbound-infbound)/steps;
            gg = infbound:stepsize:supbound;
            post_dens = normpdf(gg,a,b);
            
            %Uniform prior, the same for every parameter (mu_scale, kappa, ...),
            %so no parameter can reuse the previous parameter's prior arrays. A
            %fixed number of points is used: the posterior stepsize can be tiny
            %for a tight posterior and would make the prior grid huge.
            gg_prior=linspace(par.LB(zz),par.UB(zz),steps+1);
            prior_dens=unifpdf(gg_prior,par.LB(zz),par.UB(zz));
            
            subplot(ia,ib,zz)
            plot(gg_prior,prior_dens,'k--','linewidth',1.5); hold on
            plot(gg,post_dens,'r-','linewidth',2); hold on; vline(sol(zz));
            axis tight
            title(char(par.guess_name(zz)),'Interpreter','none');
        end
        suptitle('Priors vs. Posteriors (Laplace Approx.)');
        legend1=legend('Prior','Posterior (Laplace)');
        set(legend1,...
            'Position',[0.418594533212181 0.906613971512973 0.174041297935103 0.0217391304347826],...
            'Orientation','horizontal');
        legend boxoff;orient landscape;
        print -dpdf -fillpage prior_posterior_laplace
        
        %Check plots: log-posterior slices through the mode, one parameter at a
        %time, over +/- 2 posterior standard deviations
        if mpar.do_checkplots
            figure('name','Check Plots');
            disp('Computing checkplots. This may take a while...');
            ia=3;ib=3;
            for zz=1:numel(sol)
                check_grid=linspace(sol(zz)-2*para_standard_deviations(zz),sol(zz)+2*para_standard_deviations(zz),mpar.grdpts);
                logpost_check=NaN(1,numel(check_grid));
                for ww=1:numel(check_grid)
                    check_sol=sol;
                    check_sol(zz)=check_grid(ww);
                    [neglogpost_check]=simulate_model(check_sol,Ivec,muvec,par,U,Y,mpar,grid,solflag,Umat,bprimemat,Sy,Ry,Iy,So,Ro,Io,newdeaths_weekly,dat_young,dat_old,dt_cons,fminbnd_options,c_NoEpiSim);
                    logpost_check(ww)=-neglogpost_check;
                end
                subplot(ia,ib,zz)
                plot(check_grid,logpost_check,'k-','linewidth',2); hold on;vline(sol(zz));
                title(char(par.guess_name(zz)),'Interpreter','none');
            end
            suptitle('Log-Posterior Slices (Check plot)');
            orient landscape; print -dpdf -fillpage checkplots
        end
        
        disp(['Time in hrs: ',num2str(toc/3600)]);
        
        
        %Parameter vector for the final simulation, post_mode_mcmc: from MCMC
        %(run or loaded; presumably mcmc / mcmc_results provide post_mode_mcmc
        %-- confirm) or the optimizer's posterior mode
        if mpar.do_mcmc==1
            disp('Starting MCMC. This may take a while....');
            mcmc;
        elseif mpar.do_load_mcmc_results==1
            disp('Loading MCMC results.');
            load('mcmc_results');
        else
            disp('Taking posterior mode based on optimizer...');
            post_mode_mcmc=sol;
        end
        
        par.muvec_scale= post_mode_mcmc(1);  %scaling factor of the containment measure
        par.kappa=post_mode_mcmc(2);         %average share of consumption-based infections
        
        disp('Simulating');
        [neglogpost,cmat,comat,cymat,cons_monthly,pidy_vec_beliefs,pido_vec_beliefs,par_pi1,par_pi2,bprimemat,Umat]=simulate_model(post_mode_mcmc,Ivec,muvec,par,U,Y,mpar,grid,solflag,Umat,bprimemat,Sy,Ry,Iy,So,Ro,Io,newdeaths_weekly,dat_young,dat_old,dt_cons,fminbnd_options,c_NoEpiSim);
        par.pi1=par_pi1;
        par.pi2=par_pi2;
    else
        %simulate the model at a fixed, hard-coded parameter vector
        %(no estimation)
        disp(' ');
        disp('Simulating with given parameters');
        disp('--------------------------------');
        %par.guess_name=[{'mu_scale   '};{'kappa     '}];
        %para(1): mu_scale, scaling factor containment measure
        %para(2): kappa, avg. share of cons-based infections
        para=[
            0.525754629721827
            0.069099320356513
            ];
        
        %store the parameters in par, as the estimation branch does before
        %simulating, so par (saved later in all_results) matches the run
        par.muvec_scale=para(1);
        par.kappa=para(2);
        
        %profile on;
        tic;
        [neglogpost,cmat,comat,cymat,cons_monthly,pidy_vec_beliefs,pido_vec_beliefs,par_pi1,par_pi2,bprimemat,Umat]=simulate_model(para,Ivec,muvec,par,U,Y,mpar,grid,solflag,Umat,bprimemat,Sy,Ry,Iy,So,Ro,Io,newdeaths_weekly,dat_young,dat_old,dt_cons,fminbnd_options,c_NoEpiSim);
        toc;
        par.pi1=par_pi1;par.pi2=par_pi2;
        %profile viewer;
    end
    
    %wrap the consumption data (young, old) in timetables, using the time
    %stamps of cons_monthly.Cy without its last entry as the row times
    tty_data = timetable(cons_monthly.Cy.Time(1:end-1), ctranspose(dat_young), ...
        'VariableNames', {'Consy_data'});
    tto_data = timetable(cons_monthly.Cy.Time(1:end-1), ctranspose(dat_old), ...
        'VariableNames', {'Conso_data'});
    
    
    
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %Willingness to pay to avoid epidemic. Continue calculations here.
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    %Next, calculate time paths for assets, consumption, PV utility etc with epi
    
    ZZ=numel(Umat.so(1,:)); %periods to simulate (one per column of Umat.so)
    
    %Preallocate the simulated paths for every group (s/i/r, old/young).
    %Assets get ZZ+1 entries because the forward simulation writes
    %b_EpiSim.(.)(jj) for jj=2..ZZ+1; with length ZZ the vectors would grow
    %inside the loop. Consumption and utility are indexed 1..ZZ.
    c_EpiSim.so=NaN(ZZ,1);  b_EpiSim.so=NaN(ZZ+1,1);  U_EpiSim.so=NaN(ZZ,1);
    c_EpiSim.sy=NaN(ZZ,1);  b_EpiSim.sy=NaN(ZZ+1,1);  U_EpiSim.sy=NaN(ZZ,1);
    
    c_EpiSim.io=NaN(ZZ,1);  b_EpiSim.io=NaN(ZZ+1,1);  U_EpiSim.io=NaN(ZZ,1);
    c_EpiSim.iy=NaN(ZZ,1);  b_EpiSim.iy=NaN(ZZ+1,1);  U_EpiSim.iy=NaN(ZZ,1);
    
    c_EpiSim.ro=NaN(ZZ,1);  b_EpiSim.ro=NaN(ZZ+1,1);  U_EpiSim.ro=NaN(ZZ,1);
    c_EpiSim.ry=NaN(ZZ,1);  b_EpiSim.ry=NaN(ZZ+1,1);  U_EpiSim.ry=NaN(ZZ,1);
    
    %initial assets: every old group starts at the old-asset grid point,
    %every young group at the young-asset grid point
    b_EpiSim.so(1)=grid.b(par.grid_b_old_assets_idx); %initial assets, old
    b_EpiSim.sy(1)=grid.b(par.grid_b_young_assets_idx); %initial assets, young
    b_EpiSim.io(1)=grid.b(par.grid_b_old_assets_idx); %initial assets, old
    b_EpiSim.iy(1)=grid.b(par.grid_b_young_assets_idx); %initial assets, young
    b_EpiSim.ro(1)=grid.b(par.grid_b_old_assets_idx); %initial assets, old
    b_EpiSim.ry(1)=grid.b(par.grid_b_young_assets_idx); %initial assets, young
    
    %notation:
    %grid.b: grid of assets, t-1
    %bprime: grid of assets, t (conditional on t-1 assets), policy function
    %for assets
    %c: policy function for consumption as a function of assets, t-1
    %
    %Forward simulation: in period jj-1 each group's next-period assets come
    %from its asset policy (column jj-1 of bprimemat, interpolated over the
    %simulation grid), and consumption follows from the budget constraint
    %c = w + (1+r)*b - b'. Utility is only tracked for susceptible old and
    %young, since only those enter the willingness-to-pay calculation.
    for jj=2:ZZ+1
        %utility at current assets (susceptible old / young)
        U_EpiSim.so(jj-1) = interp1(grid.b(mpar.grid_sim_idx), Umat.so(:,jj-1), b_EpiSim.so(jj-1), 'linear', 'extrap');
        U_EpiSim.sy(jj-1) = interp1(grid.b(mpar.grid_sim_idx), Umat.sy(:,jj-1), b_EpiSim.sy(jj-1), 'linear', 'extrap');
        
        %susceptible
        b_EpiSim.so(jj)   = interp1(grid.b(mpar.grid_sim_idx), bprimemat.so(:,jj-1), b_EpiSim.so(jj-1), 'linear', 'extrap');
        c_EpiSim.so(jj-1) = par.w + (1+par.r)*b_EpiSim.so(jj-1) - b_EpiSim.so(jj);
        b_EpiSim.sy(jj)   = interp1(grid.b(mpar.grid_sim_idx), bprimemat.sy(:,jj-1), b_EpiSim.sy(jj-1), 'linear', 'extrap');
        c_EpiSim.sy(jj-1) = par.w + (1+par.r)*b_EpiSim.sy(jj-1) - b_EpiSim.sy(jj);
        
        %infected
        b_EpiSim.io(jj)   = interp1(grid.b(mpar.grid_sim_idx), bprimemat.io(:,jj-1), b_EpiSim.io(jj-1), 'linear', 'extrap');
        c_EpiSim.io(jj-1) = par.w + (1+par.r)*b_EpiSim.io(jj-1) - b_EpiSim.io(jj);
        b_EpiSim.iy(jj)   = interp1(grid.b(mpar.grid_sim_idx), bprimemat.iy(:,jj-1), b_EpiSim.iy(jj-1), 'linear', 'extrap');
        c_EpiSim.iy(jj-1) = par.w + (1+par.r)*b_EpiSim.iy(jj-1) - b_EpiSim.iy(jj);
        
        %recovered
        b_EpiSim.ro(jj)   = interp1(grid.b(mpar.grid_sim_idx), bprimemat.ro(:,jj-1), b_EpiSim.ro(jj-1), 'linear', 'extrap');
        c_EpiSim.ro(jj-1) = par.w + (1+par.r)*b_EpiSim.ro(jj-1) - b_EpiSim.ro(jj);
        b_EpiSim.ry(jj)   = interp1(grid.b(mpar.grid_sim_idx), bprimemat.ry(:,jj-1), b_EpiSim.ry(jj-1), 'linear', 'extrap');
        c_EpiSim.ry(jj-1) = par.w + (1+par.r)*b_EpiSim.ry(jj-1) - b_EpiSim.ry(jj);
    end
    
    
    %print out results and calculate willingness to pay to avoid epi in terms
    %of assets
    %
    %For each susceptible group: take the PV utility at the start of the
    %epidemic simulation, invert the no-epidemic utility schedule U.(.) over
    %grid.b to find the asset level that gives the same utility without the
    %epidemic, and express the asset difference as a percent of income.
    %Statements without a trailing semicolon are intentional: they print
    %"name = value" to the command window.
    %NOTE(review): the inverse interpolation interp1(U.so,grid.b,...)
    %presumably requires U.so / U.sy to be strictly monotone in assets --
    %confirm.
    %NOTE(review): the epidemic simulation starts from
    %grid.b(par.grid_b_*_assets_idx), while the WTP below subtracts from
    %par.old_assets / par.young_assets; confirm these coincide (or that the
    %difference is intended).
    %NOTE(review): 52*par.w + par.intinc_* is presumably annual income
    %(52 weekly wages plus interest income) -- confirm units.
    
    %old
    PV_Uso_NoEpi=U_NoEpiSim.so(1)
    PV_Uso_Epi=U_EpiSim.so(1)
    assets_so_NoEpiCounterfactual = interp1(U.so,grid.b,PV_Uso_Epi,'linear','extrap')
    %sanity check: should reproduce PV_Uso_Epi
    PV_Uso_NoEpi_Counterfactual = interp1(grid.b,U.so,assets_so_NoEpiCounterfactual,'linear','extrap')
    percent_of_income_cons_to_give_up_so=100*(par.old_assets-assets_so_NoEpiCounterfactual)/(52*par.w+par.intinc_old)
    
    %young
    PV_Usy_NoEpi=U_NoEpiSim.sy(1)
    PV_Usy_Epi=U_EpiSim.sy(1)
    assets_sy_NoEpiCounterfactual = interp1(U.sy,grid.b,PV_Usy_Epi,'linear','extrap')
    %sanity check: should reproduce PV_Usy_Epi
    PV_Usy_NoEpi_Counterfactual = interp1(grid.b,U.sy,assets_sy_NoEpiCounterfactual,'linear','extrap')
    percent_of_income_cons_to_give_up_sy=100*(par.young_assets-assets_sy_NoEpiCounterfactual)/(52*par.w+par.intinc_young)
    
    
    %a negative willingness to pay would mean the epidemic raises utility;
    %only a message is printed, execution continues
    if percent_of_income_cons_to_give_up_sy<0
        disp('percent_of_income_cons_to_give_up_sy<0: something is wrong!');
    end
    if percent_of_income_cons_to_give_up_so<0
        disp('percent_of_income_cons_to_give_up_so<0: something is wrong!');
    end
    %store the whole workspace; the else-branch of the enclosing if reloads it
    save all_results;
else
    load all_results;
end

%printout current parameters
%par.kappa_ini=(par.pi1*(par.cry_noepi_ini*par.sy_shr+par.cro_noepi_ini*(1-par.sy_shr)))/(par.pi1*(par.cry_noepi_ini*par.sy_shr+par.cro_noepi_ini*(1-par.sy_shr))+par.pi2); %take weighted avg of old and young consumption
%par.Rnot_ini=(par.pi2+par.pi1*par.cry_noepi_ini*par.sy_shr+par.pi1*par.cro_noepi_ini*(1-par.sy_shr) ) /(par.piry*par.sy_shr+par.piro*(1-par.sy_shr)+par.pidy*par.sy_shr+par.pido*(1-par.sy_shr));
%par


%reload the last saved estimation step into the workspace (presumably
%provides sol and invhessian, which are used below -- TODO confirm)
load('last_estimation_step');

%parameters at the mode
par.muvec_scale = sol(1);   %scaling factor containment measure
par.kappa       = sol(2);   %kappa, avg. share of cons-based infections


disp(' ');
disp('Simulating with estimated parameters');
disp('------------------------------------');
[neglogpost,cmat,comat,cymat,cons_monthly,pidy_vec_beliefs,pido_vec_beliefs,par_pi1,par_pi2,bprimemat,Umat]=simulate_model(sol,Ivec,muvec,par,U,Y,mpar,grid,solflag,Umat,bprimemat,Sy,Ry,Iy,So,Ro,Io,newdeaths_weekly,dat_young,dat_old,dt_cons,fminbnd_options,c_NoEpiSim);
par.pi1 = par_pi1;
par.pi2 = par_pi2;

%parameter posterior mode and posterior value
para_mode=sol;
logpost_mode=-neglogpost;   %simulate_model returns the negative log posterior

%Laplace approximation of the log marginal data density:
%  log p(Y) ~ k/2*log(2*pi) + log posterior kernel at mode + 1/2*log|Sigma|
%where k = numel(sol) and Sigma is the posterior covariance at the mode,
%i.e. the inverse Hessian of the negative log posterior. A tighter
%posterior (smaller |Sigma|) must lower the marginal likelihood, hence the
%plus sign on the log-determinant of the inverse Hessian.
%NOTE(review): assumes invhessian holds that inverse Hessian (it is not
%created in this section -- confirm against the mode-finding code).
Laplace = 1/2*numel(sol)*log(2*pi) + (logpost_mode) + 1/2*log(det(invhessian))

%plot all results
plot_results;


